{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data Process Pipeline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data Process Commands"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# TODO: change to your own project path!!!\n",
    "OPEN_SORA_HOME = \"/path/to/Open-Sora/\"\n",
    "\n",
    "\n",
    "def convert_dataset_cmd(input_dir, output_file, datatype=\"video\"):\n",
    "    \"\"\"Build the shell command that converts a data directory into a meta file.\n",
    "\n",
    "    Args:\n",
    "        input_dir: Directory handed to ``tools.datasets.convert``.\n",
    "        output_file: Path handed to ``--output``; its parent directory is created first.\n",
    "        datatype: Dataset type argument of the converter (default ``\"video\"``).\n",
    "\n",
    "    Returns:\n",
    "        ``(command, output_file)`` where ``command`` is a ``&&``-joined shell string.\n",
    "    \"\"\"\n",
    "    commands = [f'echo \"Converting {input_dir} to {output_file}\"']\n",
    "    output_dir = os.path.dirname(output_file)\n",
    "    # A bare file name has an empty dirname; `mkdir -p` without an operand fails and\n",
    "    # would abort the whole `&&` chain, so only create a directory that is named.\n",
    "    if output_dir:\n",
    "        commands.append(f\"mkdir -p {output_dir}\")\n",
    "    commands.append(f\"cd {OPEN_SORA_HOME}\")\n",
    "    commands.append(f\"python -m tools.datasets.convert {datatype} {input_dir} --output {output_file}\")\n",
    "    return \" && \".join(commands), output_file\n",
    "\n",
    "\n",
    "def get_video_info(input_file):\n",
    "    \"\"\"Compose the command that runs ``tools.datasets.datautil --info --fmin 1`` on a meta file.\n",
    "\n",
    "    Returns ``(command, output_file)``; ``output_file`` is ``input_file`` with ``_info``\n",
    "    inserted before the extension.\n",
    "    \"\"\"\n",
    "    stem, suffix = os.path.splitext(input_file)\n",
    "    output_file = f\"{stem}_info{suffix}\"\n",
    "    output_format = suffix[1:]  # extension without the leading dot\n",
    "\n",
    "    steps = [\n",
    "        f'echo \"Getting info of {input_file} to {output_file}\"',\n",
    "        f\"cd {OPEN_SORA_HOME}\",\n",
    "        f\"python -m tools.datasets.datautil {input_file} --output {output_file} --format {output_format} --info --fmin 1\",\n",
    "    ]\n",
    "    return \" && \".join(steps), output_file\n",
    "\n",
    "\n",
    "def get_video_info_torchvision(input_file):\n",
    "    \"\"\"Compose the command that runs ``tools.datasets.datautil --video-info --fmin 1`` on a meta file.\n",
    "\n",
    "    Returns ``(command, output_file)``; ``output_file`` is ``input_file`` with ``_info``\n",
    "    inserted before the extension.\n",
    "    \"\"\"\n",
    "    stem, suffix = os.path.splitext(input_file)\n",
    "    output_file = f\"{stem}_info{suffix}\"\n",
    "    output_format = suffix[1:]  # extension without the leading dot\n",
    "\n",
    "    steps = [\n",
    "        f'echo \"Getting info of {input_file} to {output_file}\"',\n",
    "        f\"cd {OPEN_SORA_HOME}\",\n",
    "        f\"python -m tools.datasets.datautil {input_file} --output {output_file} --format {output_format} --video-info --fmin 1\",\n",
    "    ]\n",
    "    return \" && \".join(steps), output_file\n",
    "\n",
    "\n",
    "def get_caption_llava7b_video(input_file):\n",
    "    \"\"\"Build the command that captions a meta file with ``tools.caption.caption_llava``.\n",
    "\n",
    "    The chain activates the ``llava2`` conda environment for the 8-process captioning run,\n",
    "    then activates ``opensora`` and merges the ``*_caption_part*`` files with\n",
    "    ``tools.datasets.datautil`` into a single cleaned caption file.\n",
    "\n",
    "    Returns:\n",
    "        ``(command, output_file)``; ``output_file`` is ``input_file`` with ``_caption``\n",
    "        inserted before the extension.\n",
    "    \"\"\"\n",
    "    base, ext = os.path.splitext(input_file)\n",
    "    output_file = f\"{base}_caption{ext}\"\n",
    "    output_format = ext[1:]\n",
    "\n",
    "    commands = [\n",
    "        # This step produces captions, so the log line names them (it used to say \"info\").\n",
    "        f'echo \"Getting caption of {input_file} to {output_file}\"',\n",
    "        f\"cd {OPEN_SORA_HOME}\",\n",
    "        \"conda activate llava2\",\n",
    "        f\"torchrun --nproc_per_node 8 --standalone -m tools.caption.caption_llava {input_file} --dp-size 8 --tp-size 1 --model-path liuhaotian/llava-v1.6-mistral-7b --prompt video\",\n",
    "        \"conda activate opensora\",\n",
    "        f\"python -m tools.datasets.datautil {base}_caption_part*{ext} --output {output_file} --format {output_format} --intersection {input_file} --clean-caption --refine-llm-caption --remove-empty-caption\",\n",
    "    ]\n",
    "    return \" && \".join(commands), output_file\n",
    "\n",
    "\n",
    "def get_caption_load(input_file):\n",
    "    \"\"\"Compose the command that loads captions via ``tools.datasets.datautil --load-caption json``.\n",
    "\n",
    "    Returns ``(command, output_file)``; ``output_file`` is ``input_file`` with ``_caption``\n",
    "    inserted before the extension.\n",
    "    \"\"\"\n",
    "    stem, suffix = os.path.splitext(input_file)\n",
    "    output_file = f\"{stem}_caption{suffix}\"\n",
    "    output_format = suffix[1:]  # extension without the leading dot\n",
    "\n",
    "    steps = [\n",
    "        f'echo \"Getting caption of {input_file} to {output_file}\"',\n",
    "        f\"cd {OPEN_SORA_HOME}\",\n",
    "        f\"python -m tools.datasets.datautil {input_file} --output {output_file} --format {output_format} --load-caption json --remove-empty-caption --clean-caption\",\n",
    "    ]\n",
    "    return \" && \".join(steps), output_file\n",
    "\n",
    "\n",
    "def get_aesthetic_score(input_file):\n",
    "    \"\"\"Compose the command that scores aesthetics and merges the per-process parts.\n",
    "\n",
    "    Runs ``tools.scoring.aesthetic.inference`` with 8 processes, then merges the\n",
    "    ``*_aes_part*`` files with ``tools.datasets.datautil --sort aes``.\n",
    "\n",
    "    Returns ``(command, output_file)``; ``output_file`` is ``input_file`` with ``_aes``\n",
    "    inserted before the extension.\n",
    "    \"\"\"\n",
    "    stem, suffix = os.path.splitext(input_file)\n",
    "    output_file = f\"{stem}_aes{suffix}\"\n",
    "    output_format = suffix[1:]  # extension without the leading dot\n",
    "\n",
    "    steps = [\n",
    "        f'echo \"Getting aesthetic score of {input_file} to {output_file}\"',\n",
    "        f\"cd {OPEN_SORA_HOME}\",\n",
    "        f\"torchrun --standalone --nproc_per_node 8 -m tools.scoring.aesthetic.inference {input_file}\",\n",
    "        f\"python -m tools.datasets.datautil {stem}_aes_part*{suffix} --output {output_file} --format {output_format} --sort aes\",\n",
    "    ]\n",
    "    return \" && \".join(steps), output_file\n",
    "\n",
    "\n",
    "def get_flow_score(input_file):\n",
    "    \"\"\"Compose the command that runs ``tools.scoring.optical_flow.inference`` with 8 processes.\n",
    "\n",
    "    Returns ``(command, output_file)``. ``output_file`` (``input_file`` with ``_flow`` before\n",
    "    the extension) is only echoed and returned; it is not passed to the tool.\n",
    "    \"\"\"\n",
    "    stem, suffix = os.path.splitext(input_file)\n",
    "    output_file = f\"{stem}_flow{suffix}\"\n",
    "\n",
    "    steps = [\n",
    "        f'echo \"Getting flow score of {input_file} to {output_file}\"',\n",
    "        f\"cd {OPEN_SORA_HOME}\",\n",
    "        f\"torchrun --standalone --nproc_per_node 8 -m tools.scoring.optical_flow.inference {input_file}\",\n",
    "    ]\n",
    "    return \" && \".join(steps), output_file\n",
    "\n",
    "\n",
    "def get_ocr(input_file):\n",
    "    \"\"\"Build the command that runs ``tools.scoring.ocr.inference`` with 8 processes.\n",
    "\n",
    "    Returns:\n",
    "        ``(command, output_file)``. ``output_file`` is only echoed and returned; it is\n",
    "        not passed to the tool.\n",
    "    \"\"\"\n",
    "    base, ext = os.path.splitext(input_file)\n",
    "    # NOTE(review): the `_match` suffix is the same one get_match_score uses and looks\n",
    "    # copied from it -- confirm which file the OCR tool actually writes.\n",
    "    output_file = f\"{base}_match{ext}\"\n",
    "\n",
    "    commands = [\n",
    "        # Log the step that really runs (the message used to say \"match score\").\n",
    "        f'echo \"Getting OCR of {input_file} to {output_file}\"',\n",
    "        f\"cd {OPEN_SORA_HOME}\",\n",
    "        f\"torchrun --standalone --nproc_per_node 8 -m tools.scoring.ocr.inference {input_file}\",\n",
    "    ]\n",
    "    return \" && \".join(commands), output_file\n",
    "\n",
    "    \n",
    "def get_match_score(input_file):\n",
    "    \"\"\"Compose the command that runs ``tools.scoring.matching.inference`` with 8 processes.\n",
    "\n",
    "    Returns ``(command, output_file)``. ``output_file`` (``input_file`` with ``_match`` before\n",
    "    the extension) is only echoed and returned; it is not passed to the tool.\n",
    "    \"\"\"\n",
    "    stem, suffix = os.path.splitext(input_file)\n",
    "    output_file = f\"{stem}_match{suffix}\"\n",
    "\n",
    "    steps = [\n",
    "        f'echo \"Getting match score of {input_file} to {output_file}\"',\n",
    "        f\"cd {OPEN_SORA_HOME}\",\n",
    "        f\"torchrun --standalone --nproc_per_node 8 -m tools.scoring.matching.inference {input_file}\",\n",
    "    ]\n",
    "    return \" && \".join(steps), output_file\n",
    "\n",
    "\n",
    "def get_cmotion_score(input_file):\n",
    "    \"\"\"Compose the command that runs ``tools.caption.camera_motion_detect`` on a meta file.\n",
    "\n",
    "    Returns ``(command, output_file)``. ``output_file`` (``input_file`` with ``_cmotion``\n",
    "    before the extension) is only echoed and returned; it is not passed to the tool.\n",
    "    \"\"\"\n",
    "    stem, suffix = os.path.splitext(input_file)\n",
    "    output_file = f\"{stem}_cmotion{suffix}\"\n",
    "\n",
    "    steps = [\n",
    "        f'echo \"Getting cmotion score of {input_file} to {output_file}\"',\n",
    "        f\"cd {OPEN_SORA_HOME}\",\n",
    "        f\"python -m tools.caption.camera_motion_detect {input_file}\",\n",
    "    ]\n",
    "    return \" && \".join(steps), output_file\n",
    "\n",
    "\n",
    "def get_commands(job_list):\n",
    "    \"\"\"Chain several command builders into one ``&&``-joined shell pipeline.\n",
    "\n",
    "    Each job is a dict holding a builder under ``\"cmd\"`` plus the keyword arguments for\n",
    "    that builder. Once a builder has returned an output file, that path is passed to the\n",
    "    next builder as ``input_file``.\n",
    "\n",
    "    The job dicts are copied before use, so ``job_list`` is left untouched and the same\n",
    "    list can be passed again (for example when a notebook cell is re-run).\n",
    "\n",
    "    Returns:\n",
    "        ``(command, output_file)``: the full shell command, ending with an \"All Done!\" echo,\n",
    "        and the output file of the last builder (``None`` for an empty ``job_list``).\n",
    "    \"\"\"\n",
    "    commands = []\n",
    "    output_file = None\n",
    "    for job in job_list:\n",
    "        kwargs = dict(job)  # shallow copy: never pop from / write into the caller's dict\n",
    "        cmd = kwargs.pop(\"cmd\")\n",
    "        if output_file is not None:\n",
    "            # Feed the previous stage's output into this stage.\n",
    "            kwargs[\"input_file\"] = output_file\n",
    "        command, output_file = cmd(**kwargs)\n",
    "        commands.append(command)\n",
    "    commands.append('echo \"All Done!\"')\n",
    "    return \" && \".join(commands), output_file"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Remote Launch via Paramiko\n",
    "\n",
    "First, add hosts to `~/.ssh/config`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import paramiko\n",
    "\n",
    "# Host aliases used as the default targets of run_command_all_hosts.\n",
    "HOSTS = [\"host-0\", \"host-1\", \"host-2\", \"host-3\", \"host-4\", \"host-5\", \"host-6\", \"host-7\"]\n",
    "\n",
    "# Parse the user's SSH client configuration (~/.ssh/config) once, when this cell runs.\n",
    "# If the file does not exist, nothing is parsed into ssh_config.\n",
    "ssh_config = paramiko.SSHConfig()\n",
    "user_config_file = os.path.expanduser(\"~/.ssh/config\")\n",
    "if os.path.exists(user_config_file):\n",
    "    with open(user_config_file) as f:\n",
    "        ssh_config.parse(f)\n",
    "\n",
    "\n",
    "def get_ssh_config(hostname):\n",
    "    \"\"\"Translate the ``~/.ssh/config`` entry for ``hostname`` into ``SSHClient.connect`` kwargs.\n",
    "\n",
    "    ``User``, ``Port`` and ``IdentityFile`` are optional in an SSH config. A missing port\n",
    "    falls back to 22; a missing user or identity file is simply left out so that\n",
    "    paramiko applies its own defaults instead of this function raising ``KeyError``.\n",
    "\n",
    "    Args:\n",
    "        hostname: Host alias to look up.\n",
    "\n",
    "    Returns:\n",
    "        Dict with ``hostname`` and ``port``, plus ``username`` / ``key_filename`` when configured.\n",
    "    \"\"\"\n",
    "    user_config = ssh_config.lookup(hostname)\n",
    "    cfg = {\n",
    "        \"hostname\": user_config[\"hostname\"],\n",
    "        \"port\": int(user_config.get(\"port\", 22)),\n",
    "    }\n",
    "    if \"user\" in user_config:\n",
    "        cfg[\"username\"] = user_config[\"user\"]\n",
    "    if \"identityfile\" in user_config:\n",
    "        cfg[\"key_filename\"] = user_config[\"identityfile\"]\n",
    "    return cfg\n",
    "\n",
    "\n",
    "def connect(hostname):\n",
    "    \"\"\"Open an SSH session to ``hostname`` using the settings from ``get_ssh_config``.\n",
    "\n",
    "    Unknown host keys are accepted automatically (``AutoAddPolicy``).\n",
    "    Returns the connected ``paramiko.SSHClient`` without closing it.\n",
    "    \"\"\"\n",
    "    settings = get_ssh_config(hostname)\n",
    "    ssh = paramiko.SSHClient()\n",
    "    ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())\n",
    "    ssh.connect(**settings)\n",
    "    return ssh\n",
    "\n",
    "\n",
    "def run_command(command, hostname, nohup=False, log_file=None, sleep=None):\n",
    "    \"\"\"Run a shell command on a remote host and print what it wrote.\n",
    "\n",
    "    Args:\n",
    "        command: Shell command line, executed through ``bash -ic`` on the remote side.\n",
    "        hostname: Host alias passed to ``connect``.\n",
    "        nohup: If true, wrap the command in ``nohup ... &``.\n",
    "        log_file: If given, append the command's stdout and stderr to this remote path.\n",
    "        sleep: If given, prefix the command with ``sleep <value>;``.\n",
    "    \"\"\"\n",
    "    import shlex  # local import: only needed for quoting in this function\n",
    "\n",
    "    client = connect(hostname)\n",
    "    try:\n",
    "        print(\"HOST:\", hostname)\n",
    "        if sleep:\n",
    "            command = f\"sleep {sleep}; {command}\"\n",
    "        # Quote with shlex so a command that itself contains single quotes still\n",
    "        # reaches bash as one argument.\n",
    "        command = f\"bash -ic {shlex.quote(command)}\"\n",
    "        if log_file:\n",
    "            command = f\"{command} >> {log_file} 2>&1\"\n",
    "        if nohup:\n",
    "            command = f\"nohup {command} &\"\n",
    "        print(\"COMMAND:\", command)\n",
    "        stdin, stdout, stderr = client.exec_command(command, get_pty=False)\n",
    "\n",
    "        stdout_str = stdout.read().decode()\n",
    "        stderr_str = stderr.read().decode()\n",
    "        if stdout_str:\n",
    "            print(\"==== STDOUT ====\\n\", stdout_str)\n",
    "        if stderr_str:\n",
    "            print(\"==== STDERR ====\\n\", stderr_str)\n",
    "    finally:\n",
    "        # Release the SSH connection even when executing or reading fails.\n",
    "        client.close()\n",
    "\n",
    "\n",
    "def run_command_all_hosts(command, hosts=HOSTS):\n",
    "    \"\"\"Run ``command`` on each entry of ``hosts`` (default: ``HOSTS``), one after another.\"\"\"\n",
    "    for target in hosts:\n",
    "        run_command(command, hostname=target)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here are tools to examine each machine's status."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def nvidia_smi(host=None):\n",
    "    \"\"\"Run ``nvidia-smi`` on ``host``, or on every default host when ``host`` is omitted.\n",
    "\n",
    "    ``host`` defaults to ``None`` like in ``nvitop`` and ``ps``, so the all-hosts\n",
    "    branch can be reached with a plain ``nvidia_smi()`` call.\n",
    "    \"\"\"\n",
    "    if host:\n",
    "        run_command(\"nvidia-smi\", host)\n",
    "    else:\n",
    "        run_command_all_hosts(\"nvidia-smi\")\n",
    "\n",
    "\n",
    "def nvitop(host=None):\n",
    "    \"\"\"Run ``nvitop -1`` on ``host``, or on every default host when ``host`` is falsy.\"\"\"\n",
    "    if not host:\n",
    "        run_command_all_hosts(\"/home/user/.local/bin/nvitop -1\")\n",
    "        return\n",
    "    run_command(\"/home/user/.local/bin/nvitop -1\", host)\n",
    "\n",
    "\n",
    "def ps(host=None, interest=\"python|sleep|torchrun|colossal\", all=True):\n",
    "    \"\"\"List remote processes, optionally filtered by an extended-regex pattern.\n",
    "\n",
    "    Args:\n",
    "        host: Host alias to query; when falsy, every default host is queried.\n",
    "        interest: Pattern handed to ``grep -E``; ``None`` disables filtering.\n",
    "        all: Use ``ps aux`` when true, ``ps ux`` otherwise.\n",
    "    \"\"\"\n",
    "    listing = \"ps aux\" if all else \"ps ux\"\n",
    "    if interest is None:\n",
    "        remote_cmd = f\"{listing} | cat\"\n",
    "    else:\n",
    "        remote_cmd = f'{listing} | cat | grep --color=never -E \"{interest}\"'\n",
    "\n",
    "    if host:\n",
    "        run_command(remote_cmd, host)\n",
    "    else:\n",
    "        run_command_all_hosts(remote_cmd)\n",
    "\n",
    "\n",
    "def kill(pid, host):\n",
    "    \"\"\"Run ``kill -KILL <pid>`` on ``host``.\"\"\"\n",
    "    kill_cmd = f\"kill -KILL {pid}\"\n",
    "    run_command(kill_cmd, host)\n",
    "\n",
    "\n",
    "def pkill(interest, host):\n",
    "    \"\"\"Run ``pkill -9 -f`` with the pattern ``interest`` on ``host``.\"\"\"\n",
    "    pkill_cmd = f'pkill -9 -f \"{interest}\"'\n",
    "    run_command(pkill_cmd, host)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Example\n",
    "\n",
    "Remote launch via paramiko."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Everything this cell uses is defined here, so it also works on a fresh kernel\n",
    "# (previously `cmd`, `host` and `log_file` only existed after running a later cell).\n",
    "# TODO: replace the placeholder host, paths and job list with your own.\n",
    "host = \"host-0\"\n",
    "log_file = os.path.join(OPEN_SORA_HOME, \"logs/data_process.log\")\n",
    "job_list = [\n",
    "    {\"cmd\": convert_dataset_cmd, \"input_dir\": \"/path/to/videos\", \"output_file\": \"/path/to/meta.csv\"},\n",
    "    {\"cmd\": get_video_info},\n",
    "]\n",
    "cmd, output_file = get_commands(job_list)\n",
    "\n",
    "sleep = None\n",
    "run_command(cmd, host, log_file=log_file, nohup=True, sleep=sleep)\n",
    "ps(host)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use the following commands to monitor the status of the jobs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# No host given: list the matching processes on every host in HOSTS.\n",
    "ps()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `host` must already be set by a previously executed cell.\n",
    "nvitop(host)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: set `pid` to the process id reported by ps(host); nothing is killed while it is None.\n",
    "pid = None\n",
    "if pid is not None:\n",
    "    kill(pid, host)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def colossal_run(data_path, load_path=None):\n",
    "    \"\"\"Compose the ``colossalai run`` training command.\n",
    "\n",
    "    Args:\n",
    "        data_path: Value for ``--data-path``.\n",
    "        load_path: Optional value for ``--load-path``; the flag is omitted when falsy.\n",
    "\n",
    "    Returns:\n",
    "        Shell string that changes into ``OPEN_SORA_HOME`` and then launches training.\n",
    "    \"\"\"\n",
    "    launch = f\"colossalai run --nproc_per_node 8 --hostfile hostfile scripts/train.py configs/opensora-v1-1/train/video.py --wandb True --data-path {data_path}\"\n",
    "    if load_path:\n",
    "        launch = f\"{launch} --load-path {load_path}\"\n",
    "    return \" && \".join([f\"cd {OPEN_SORA_HOME}\", launch])\n",
    "\n",
    "\n",
    "def kill_all():\n",
    "    \"\"\"Compose the command that runs ``pkill -9 python`` over ssh on every entry of ``hostfile``.\n",
    "\n",
    "    Returns the shell string; nothing is executed here.\n",
    "    \"\"\"\n",
    "    steps = (\n",
    "        f\"cd {OPEN_SORA_HOME}\",\n",
    "        'cat hostfile  | xargs -I \"{}\" ssh \"{}\" pkill -9 python',\n",
    "    )\n",
    "    return \" && \".join(steps)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings for the training launch. TODO: replace the placeholder data path.\n",
    "host = \"host-0\"\n",
    "log_file = os.path.join(OPEN_SORA_HOME, \"logs/train.log\")\n",
    "data_path = \"/path/to/meta.csv\"\n",
    "cmd = colossal_run(data_path)\n",
    "# Show the composed command before it is launched in the next cell.\n",
    "print(cmd)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# nohup=True wraps the command in `nohup ... &`; its output is appended to log_file.\n",
    "run_command(cmd, host, log_file=log_file, nohup=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# WARNING: runs `pkill -9 python` on every host in the hostfile, not only on training jobs.\n",
    "cmd = kill_all()\n",
    "run_command(cmd, host)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
